package com.atguigu.gmall.realtime.dwd.db.split.function;

import com.alibaba.fastjson.JSONObject;
import com.atguigu.gmall.realtime.common.bean.TableProcessDwd;
import com.atguigu.gmall.realtime.common.util.JDBCUtil;
import org.apache.flink.api.common.state.BroadcastState;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReadOnlyBroadcastState;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
import org.apache.flink.util.Collector;

import java.sql.Connection;
import java.util.*;

/**
 * Routes change records to their DWD fact-table configuration.
 *
 * <p>Each incoming record is looked up by {@code table:type} against the
 * {@code table_process_dwd} configuration. The lookup checks the broadcast config stream
 * first, then a map pre-loaded from MySQL in {@link #open(Configuration)}. When a matching
 * config row exists, the record's {@code data} object is trimmed to the configured sink
 * columns, stamped with the record's {@code ts}, and emitted together with the config row.
 * Records without a matching config row are dropped.
 */
public class BaseDbTableProcessFunction extends BroadcastProcessFunction<JSONObject, TableProcessDwd, Tuple2<JSONObject, TableProcessDwd>> {

    /** Query used to pre-load every config row when the function is opened. */
    private static final String CONFIG_SQL = "select * from gmall0221_config.table_process_dwd";

    /** Descriptor of the broadcast state: key is {@code sourceTable:sourceType}, value is the config row. */
    private final MapStateDescriptor<String, TableProcessDwd> mapStateDescriptor;

    /**
     * Config rows loaded from MySQL in {@link #open(Configuration)}, keyed like the broadcast state.
     *
     * <p>The broadcast state is only filled by {@link #processBroadcastElement}. This map lets
     * records that arrive before the matching broadcast element still find their config. It is
     * kept in sync with broadcast puts and deletes.
     */
    final Map<String, TableProcessDwd> configMap = new HashMap<>();

    /**
     * @param mapStateDescriptor descriptor of the broadcast state shared with the config stream
     */
    public BaseDbTableProcessFunction(MapStateDescriptor<String, TableProcessDwd> mapStateDescriptor) {
        this.mapStateDescriptor = mapStateDescriptor;
    }

    /**
     * Pre-loads all {@code table_process_dwd} rows into {@link #configMap}.
     *
     * <p>The connection is closed in a {@code finally} block so a failing query does not leak it.
     */
    @Override
    public void open(Configuration parameters) throws Exception {
        Connection mysqlConnection = JDBCUtil.getMysqlConnection();
        try {
            List<TableProcessDwd> tableProcessDwdList =
                    JDBCUtil.queryList(mysqlConnection, CONFIG_SQL, TableProcessDwd.class, true);
            for (TableProcessDwd tableProcessDwd : tableProcessDwdList) {
                String key = getKey(tableProcessDwd.getSourceTable(), tableProcessDwd.getSourceType());
                configMap.put(key, tableProcessDwd);
            }
        } finally {
            JDBCUtil.closeMysqlConnection(mysqlConnection);
        }
    }

    /** Builds the lookup key shared by the broadcast state and {@link #configMap}. */
    private static String getKey(String sourceTable, String sourceType) {
        return sourceTable + ":" + sourceType;
    }

    /**
     * Emits {@code (trimmedData, config)} when a config row exists for the record's
     * {@code table:type}; otherwise drops the record.
     */
    @Override
    public void processElement(JSONObject jsonObj, BroadcastProcessFunction<JSONObject, TableProcessDwd, Tuple2<JSONObject, TableProcessDwd>>.ReadOnlyContext readOnlyContext, Collector<Tuple2<JSONObject, TableProcessDwd>> collector) throws Exception {
        String key = getKey(jsonObj.getString("table"), jsonObj.getString("type"));
        ReadOnlyBroadcastState<String, TableProcessDwd> broadcastState = readOnlyContext.getBroadcastState(mapStateDescriptor);

        // Look up by the computed key, not the literal "key". Fall back to the pre-loaded
        // map for records that arrive before the matching broadcast element.
        TableProcessDwd tableProcessDwd = broadcastState.get(key);
        if (tableProcessDwd == null) {
            tableProcessDwd = configMap.get(key);
        }
        if (tableProcessDwd == null) {
            return;
        }

        // NOTE(review): assumes "data" is present for every configured table:type; a missing
        // object would throw NullPointerException below, as in the original — confirm upstream.
        JSONObject dataJSONObj = jsonObj.getJSONObject("data");
        deleteNotNeedColumn(dataJSONObj, tableProcessDwd.getSinkColumns());
        dataJSONObj.put("ts", jsonObj.getLong("ts"));
        collector.collect(Tuple2.of(dataJSONObj, tableProcessDwd));
    }

    /**
     * Removes, in place, every field of {@code dataJSONObj} whose name is not in the
     * comma-separated {@code sinkColumns} list.
     */
    private static void deleteNotNeedColumn(JSONObject dataJSONObj, String sinkColumns) {
        // A HashSet gives O(1) membership per field instead of a linear List scan.
        Set<String> columnSet = new HashSet<>(Arrays.asList(sinkColumns.split(",")));
        dataJSONObj.entrySet().removeIf(entry -> !columnSet.contains(entry.getKey()));
    }

    /**
     * Applies a config change to both the broadcast state and {@link #configMap}.
     * Op {@code "d"} removes the entry; any other op inserts or overwrites it.
     */
    @Override
    public void processBroadcastElement(TableProcessDwd tp, BroadcastProcessFunction<JSONObject, TableProcessDwd, Tuple2<JSONObject, TableProcessDwd>>.Context context, Collector<Tuple2<JSONObject, TableProcessDwd>> collector) throws Exception {
        BroadcastState<String, TableProcessDwd> broadcastState = context.getBroadcastState(mapStateDescriptor);
        String key = getKey(tp.getSourceTable(), tp.getSourceType());
        if ("d".equals(tp.getOp())) {
            broadcastState.remove(key);
            configMap.remove(key);
        } else {
            broadcastState.put(key, tp);
            configMap.put(key, tp);
        }
    }
}
